package syncx

import (
	"hash/fnv"
	"sync"

	"github.com/spaolacci/murmur3"
	"github.com/zeromicro/go-zero/core/logx"
)

// ShardedLock is a fixed-size array of mutexes ("shards"). Callers lock one
// shard at a time, selected either directly by index or by hashing a string
// key, so that work on unrelated keys usually does not contend on one mutex.
type ShardedLock struct {
	shardCount int          // number of shards; key hashes are reduced modulo this value
	shards     []sync.Mutex // one mutex per shard, indexed by shard number
}

// NewShardedLock returns a ShardedLock holding shardCount independent mutexes.
//
// A shardCount below 1 is raised to 1. Without that guard a negative count
// panics inside make, and a zero count produces a lock with no shards whose
// key hashing divides by zero. With it, such input degrades to a single
// mutex shared by every key.
func NewShardedLock(shardCount int) *ShardedLock {
	if shardCount < 1 {
		shardCount = 1
	}
	return &ShardedLock{
		shardCount: shardCount,
		shards:     make([]sync.Mutex, shardCount),
	}
}

// Lock acquires the mutex of the shard with the given index, blocking until
// it is available. shard must be a valid index into the shard array; an
// out-of-range value panics on the slice access.
func (s *ShardedLock) Lock(shard int) {
	s.shards[shard].Lock()
}

// Unlock releases the mutex of the shard with the given index. shard must be
// a valid index into the shard array, and that shard must currently be
// locked: unlocking an unlocked sync.Mutex is a fatal run-time error.
func (s *ShardedLock) Unlock(shard int) {
	s.shards[shard].Unlock()
}

// LockByKey acquires the mutex of the shard that key maps to through
// GetHashShard, blocking until that shard is available. The key and the
// chosen shard index are logged at info level before the lock is taken.
//
// Distinct keys can map to the same shard, so a goroutine that already holds
// one key's lock may block on itself if it then locks a second key.
// Release the lock with UnlockByKey using the same key.
func (s *ShardedLock) LockByKey(key string) {
	idx := s.GetHashShard(key)
	logx.Infof("ShardedLock lock key:%s shard:%d", key, idx)

	mu := &s.shards[idx]
	mu.Lock()
}

// UnlockByKey releases the mutex of the shard that key maps to through
// GetHashShard. The key and the shard index are logged at info level before
// the mutex is released, so the log call runs while the shard is still held.
//
// The shard must currently be locked: unlocking an unlocked sync.Mutex is a
// fatal run-time error.
func (s *ShardedLock) UnlockByKey(key string) {
	idx := s.GetHashShard(key)
	logx.Infof("ShardedLock unlock key:%s shard:%d", key, idx)

	mu := &s.shards[idx]
	mu.Unlock()
}

// GetHashShard maps key to a shard index in [0, shardCount). It hashes the
// key's bytes with 32-bit FNV-1 and reduces the hash modulo the shard count.
// The same key always yields the same index for a given shard count.
//
// The receiver's shardCount must be non-zero; a zero count makes the modulo
// operation panic with a division by zero.
func (s *ShardedLock) GetHashShard(key string) int {
	h := fnv.New32()
	// The FNV hash's Write never returns an error, so the result is discarded.
	_, _ = h.Write([]byte(key))
	return int(h.Sum32() % uint32(s.shardCount))
}

// GetHashShardV2 maps key to a shard index in [0, shardCount) using a 32-bit
// MurmurHash3 of the key's bytes. It is intended to spread keys across shards
// more evenly than GetHashShard.
//
// The receiver's shardCount must be non-zero; a zero count makes the modulo
// operation panic with a division by zero.
func (s *ShardedLock) GetHashShardV2(key string) int {
	h := murmur3.New32()
	// The write result is discarded, as hash writers are not expected to fail.
	_, _ = h.Write([]byte(key))
	sum := h.Sum32()

	// Fold the upper 16 bits into the lower 16 so the high bits of the hash
	// also influence the remainder taken below.
	sum ^= sum >> 16

	return int(sum % uint32(s.shardCount))
}
